{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 402 RNN\n",
    "\n",
    "For more tutorials, visit my tutorial page: https://morvanzhou.github.io/tutorials/\n",
    "My YouTube channel: https://www.youtube.com/user/MorvanZhou\n",
    "\n",
    "Dependencies:\n",
    "* torch: 0.1.11\n",
    "* matplotlib\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.autograd import Variable\n",
    "import torchvision.datasets as dsets\n",
    "import torchvision.transforms as transforms\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7f5864059930>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.manual_seed(1)    # fix the random seed so that runs are reproducible"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Hyper Parameters\n",
    "# Number of passes over the training data; a single epoch keeps the run short.\n",
    "EPOCH = 1\n",
    "# Mini-batch size used by the training data loader.\n",
    "BATCH_SIZE = 64\n",
    "# rnn time step / image height, and rnn input size / image width.\n",
    "TIME_STEP = 28\n",
    "INPUT_SIZE = 28\n",
    "# Learning rate handed to the optimizer.\n",
    "LR = 0.01\n",
    "# Set to True if the MNIST data has not been downloaded yet.\n",
    "DOWNLOAD_MNIST = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# MNIST handwritten-digit dataset (training split)\n",
    "# ToTensor converts a PIL.Image or numpy.ndarray to a torch.FloatTensor\n",
    "# of shape (C x H x W) and normalizes it to the range [0.0, 1.0].\n",
    "to_tensor = transforms.ToTensor()\n",
    "train_data = dsets.MNIST(\n",
    "    root='./mnist/',\n",
    "    train=True,                 # this is the training data\n",
    "    download=DOWNLOAD_MNIST,    # download it if you don't have it\n",
    "    transform=to_tensor,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([60000, 28, 28])\n",
      "torch.Size([60000])\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAEICAYAAACQ6CLfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADrpJREFUeJzt3X2sVHV+x/HPp6hpxAekpkhYLYsxGDWWbRAbQ1aNYX2I\nRlFjltSERiP7hyRu0pAa+sdqWqypD81SzQY26kKzdd1EjehufKiobGtCvCIq4qKu0SzkCjWIAj5Q\nuN/+cYftXb3zm8vMmTnD/b5fyeTOnO+cOd+c8OE8zvwcEQKQz5/U3QCAehB+ICnCDyRF+IGkCD+Q\nFOEHkiL8QFKEH6Oy/aLtL23vaTy21N0TqkX4UbI4Io5pPGbW3QyqRfiBpAg/Sv7Z9se2/9v2BXU3\ng2qZe/sxGtvnStosaZ+k70u6T9KsiPhdrY2hMoQfY2L7aUm/ioh/q7sXVIPdfoxVSHLdTaA6hB/f\nYHuS7Ytt/6ntI2z/jaTvSnq67t5QnSPqbgB96UhJ/yTpdEkHJP1W0lUR8U6tXaFSHPMDSbHbDyRF\n+IGkCD+QFOEHkurp2X7bnF0EuiwixnQ/RkdbftuX2N5i+z3bt3byWQB6q+1LfbYnSHpH0jxJWyW9\nImlBRGwuzMOWH+iyXmz550h6LyLej4h9kn4h6coOPg9AD3US/mmSfj/i9dbGtD9ie5HtAdsDHSwL\nQMW6fsIvIlZKWimx2w/0k062/NsknTzi9bca0wAcBjoJ/yuSTrP9bdtHafgHH9ZU0xaAbmt7tz8i\n9tteLOkZSRMkPRgRb1XWGYCu6um3+jjmB7qvJzf5ADh8EX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQf\nSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKE\nH0iK8ANJEX4gKcIPJEX4gaQIP5BU20N04/AwYcKEYv3444/v6vIXL17ctHb00UcX5505c2axfvPN\nNxfrd999d9PaggULivN++eWXxfqdd95ZrN9+++3Fej/oKPy2P5C0W9IBSfsjYnYVTQHoviq2/BdG\nxMcVfA6AHuKYH0iq0/CHpGdtv2p70WhvsL3I9oDtgQ6XBaBCne72z42Ibbb/XNJztn8bEetGviEi\nVkpaKUm2o8PlAahIR1v+iNjW+LtD0uOS5lTRFIDuazv8tifaPvbgc0nfk7SpqsYAdFcnu/1TJD1u\n++Dn/EdEPF1JV+PMKaecUqwfddRRxfp5551XrM+dO7dpbdKkScV5r7nmmmK9Tlu3bi3Wly9fXqzP\nnz+/aW337t3FeV9//fVi/aWXXirWDwdthz8i3pf0lxX2AqCHuNQHJEX4gaQIP5AU4QeSIvxAUo7o\n3U134/UOv1mzZhXra9euLda7/bXafjU0NFSs33DDDcX6nj172l724OBgsf7JJ58U61u2bGl72d0W\nER7L+9jyA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBSXOevwOTJk4v19evXF+szZsyosp1Ktep9165d\nxfqFF17YtLZv377ivFnvf+gU1/kBFBF+ICnCDyRF+IGkCD+QFOEHkiL8QFIM0V2BnTt3FutLliwp\n1i+//PJi/bXXXivWW/2EdcnGjRuL9Xnz5hXre/fuLdbPPPPMprVbbrmlOC+6iy0/kBThB5Ii/EBS\nhB9IivADSRF+ICnCDyTF9/n7wHHHHVestxpOesWKFU1rN954Y3He66+/vlh/+OGHi3X0n8q+z2/7\nQds7bG8aMW2y7edsv9v4e0InzQLovbHs9v9M0iVfm3arpOcj4jRJzzdeAziMtAx/RKyT9PX7V6+U\ntKrxfJWkqyruC0CXtXtv/5SIODjY2UeSpjR7o+1Fkha1uRwAXdLxF3siIkon8iJipaSVEif8gH7S\n7qW+7banSlLj747qWgLQC+2Gf42khY
3nCyU9UU07AHql5W6/7YclXSDpRNtbJf1I0p2Sfmn7Rkkf\nSrqum02Od5999llH83/66adtz3vTTTcV64888kixPjQ01PayUa+W4Y+IBU1KF1XcC4Ae4vZeICnC\nDyRF+IGkCD+QFOEHkuIrvePAxIkTm9aefPLJ4rznn39+sX7ppZcW688++2yxjt5jiG4ARYQfSIrw\nA0kRfiApwg8kRfiBpAg/kBTX+ce5U089tVjfsGFDsb5r165i/YUXXijWBwYGmtbuv//+4ry9/Lc5\nnnCdH0AR4QeSIvxAUoQfSIrwA0kRfiApwg8kxXX+5ObPn1+sP/TQQ8X6scce2/ayly5dWqyvXr26\nWB8cHCzWs+I6P4Aiwg8kRfiBpAg/kBThB5Ii/EBShB9Iiuv8KDrrrLOK9XvvvbdYv+ii9gdzXrFi\nRbG+bNmyYn3btm1tL/twVtl1ftsP2t5he9OIabfZ3mZ7Y+NxWSfNAui9sez2/0zSJaNM/9eImNV4\n/LratgB0W8vwR8Q6STt70AuAHurkhN9i2280DgtOaPYm24tsD9hu/mNuAHqu3fD/RNKpkmZJGpR0\nT7M3RsTKiJgdEbPbXBaALmgr/BGxPSIORMSQpJ9KmlNtWwC6ra3w25464uV8SZuavRdAf2p5nd/2\nw5IukHSipO2SftR4PUtSSPpA0g8iouWXq7nOP/5MmjSpWL/iiiua1lr9VoBdvly9du3aYn3evHnF\n+ng11uv8R4zhgxaMMvmBQ+4IQF/h9l4gKcIPJEX4gaQIP5AU4QeS4iu9qM1XX31VrB9xRPli1P79\n+4v1iy++uGntxRdfLM57OOOnuwEUEX4gKcIPJEX4gaQIP5AU4QeSIvxAUi2/1Yfczj777GL92muv\nLdbPOeecprVW1/Fb2bx5c7G+bt26jj5/vGPLDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJcZ1/nJs5\nc2axvnjx4mL96quvLtZPOumkQ+5prA4cOFCsDw6Wfy1+aGioynbGHbb8QFKEH0iK8ANJEX4gKcIP\nJEX4gaQIP5BUy+v8tk+WtFrSFA0Pyb0yIn5se7KkRyRN1/Aw3ddFxCfdazWvVtfSFywYbSDlYa2u\n40+fPr2dlioxMDBQrC9btqxYX7NmTZXtpDOWLf9+SX8XEWdI+mtJN9s+Q9Ktkp6PiNMkPd94DeAw\n0TL8ETEYERsaz3dLelvSNElXSlrVeNsqSVd1q0kA1TukY37b0yV9R9J6SVMi4uD9lR9p+LAAwGFi\nzPf22z5G0qOSfhgRn9n/PxxYRESzcfhsL5K0qNNGAVRrTFt+20dqOPg/j4jHGpO3257aqE+VtGO0\neSNiZUTMjojZVTQMoBotw+/hTfwDkt6OiHtHlNZIWth4vlDSE9W3B6BbWg7RbXuupN9IelPSwe9I\nLtXwcf8vJZ0i6UMNX+rb2eKzUg7RPWVK+XTIGWecUazfd999xfrpp59+yD1VZf369cX6XXfd1bT2\nxBPl7QVfyW3PWIfobnnMHxH/JanZh110KE0B6B/c4QckRfiBpAg/kBThB5Ii/EBShB9Iip/uHqPJ\nkyc3ra1YsaI476xZs4r1GTNmtNVTFV5++eVi/Z577inWn3nmmWL9iy++OOSe0Bts+YGkCD+QFOEH\nkiL8QFKEH0iK8ANJEX4gqTTX+c8999xifcmSJcX6nDlzmtamTZvWVk9V+fzzz5vWli9fXpz3jjvu\nKNb37t3bVk/of2z5gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiCpNNf558+f31G9E5s3by7Wn3rqqWJ9\n//79xXrpO/e7du0qzou82PIDSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKOiPIb7JMlrZY0RVJIWhkR\nP7Z9m6SbJP1P461LI+LXLT6rvDAAHYsIj+V9Ywn/VElTI2KD7WMlvSrpKknXSdoTEXePtSnCD3Tf\nWM
Pf8g6/iBiUNNh4vtv225Lq/ekaAB07pGN+29MlfUfS+sakxbbfsP2g7ROazLPI9oDtgY46BVCp\nlrv9f3ijfYyklyQti4jHbE+R9LGGzwP8o4YPDW5o8Rns9gNdVtkxvyTZPlLSU5KeiYh7R6lPl/RU\nRJzV4nMIP9BlYw1/y91+25b0gKS3Rwa/cSLwoPmSNh1qkwDqM5az/XMl/UbSm5KGGpOXSlogaZaG\nd/s/kPSDxsnB0mex5Qe6rNLd/qoQfqD7KtvtBzA+EX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrw\nA0kRfiApwg8kRfiBpAg/kBThB5Lq9RDdH0v6cMTrExvT+lG/9tavfUn01q4qe/uLsb6xp9/n/8bC\n7YGImF1bAwX92lu/9iXRW7vq6o3dfiApwg8kVXf4V9a8/JJ+7a1f+5LorV219FbrMT+A+tS95QdQ\nE8IPJFVL+G1fYnuL7fds31pHD83Y/sD2m7Y31j2+YGMMxB22N42YNtn2c7bfbfwddYzEmnq7zfa2\nxrrbaPuymno72fYLtjfbfsv2LY3pta67Ql+1rLeeH/PbniDpHUnzJG2V9IqkBRGxuaeNNGH7A0mz\nI6L2G0Jsf1fSHkmrDw6FZvtfJO2MiDsb/3GeEBF/3ye93aZDHLa9S701G1b+b1XjuqtyuPsq1LHl\nnyPpvYh4PyL2SfqFpCtr6KPvRcQ6STu/NvlKSasaz1dp+B9PzzXprS9ExGBEbGg83y3p4LDyta67\nQl+1qCP80yT9fsTrrapxBYwiJD1r+1Xbi+puZhRTRgyL9pGkKXU2M4qWw7b30teGle+bddfOcPdV\n44TfN82NiL+SdKmkmxu7t30pho/Z+ula7U8knarhMRwHJd1TZzONYeUflfTDiPhsZK3OdTdKX7Ws\ntzrCv03SySNef6sxrS9ExLbG3x2SHtfwYUo/2X5whOTG3x019/MHEbE9Ig5ExJCkn6rGddcYVv5R\nST+PiMcak2tfd6P1Vdd6qyP8r0g6zfa3bR8l6fuS1tTQxzfYntg4ESPbEyV9T/039PgaSQsbzxdK\neqLGXv5Ivwzb3mxYedW87vpuuPuI6PlD0mUaPuP/O0n/UEcPTfqaIen1xuOtunuT9LCGdwP/V8Pn\nRm6U9GeSnpf0rqT/lDS5j3r7dw0P5f6GhoM2tabe5mp4l/4NSRsbj8vqXneFvmpZb9zeCyTFCT8g\nKcIPJEX4gaQIP5AU4QeSIvxAUoQfSOr/AH6evjIXWuv8AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f580f9626a0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Show the size of the raw training tensors, then draw the first digit with its label as the title.\n",
    "train_images = train_data.train_data\n",
    "train_labels = train_data.train_labels\n",
    "print(train_images.size())     # (60000, 28, 28)\n",
    "print(train_labels.size())     # (60000)\n",
    "first_image = train_images[0].numpy()\n",
    "plt.imshow(first_image, cmap='gray')\n",
    "plt.title('%i' % train_labels[0])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Mini-batch iterator over the training set (shuffled) used by the training loop\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    dataset=train_data,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Build a small evaluation set: only the first N_TEST test samples are used to speed up testing.\n",
    "N_TEST = 2000\n",
    "test_data = dsets.MNIST(root='./mnist/', train=False, transform=transforms.ToTensor())\n",
    "# Slice BEFORE the float conversion so only N_TEST images (not the whole test split) are converted.\n",
    "# The raw pixel values are divided by 255 by hand to bring them into the range [0, 1].\n",
    "# volatile=True marks the Variable as inference-only on old PyTorch; NOTE: PyTorch >= 0.4 ignores it with a warning.\n",
    "test_x = Variable(test_data.test_data[:N_TEST], volatile=True).type(torch.FloatTensor) / 255.   # shape (N_TEST, 28, 28), values in [0, 1]\n",
    "test_y = test_data.test_labels[:N_TEST].numpy().squeeze()    # convert the labels to a numpy array"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class RNN(nn.Module):\n",
    "    \"\"\"LSTM followed by a linear layer; classifies a sequence from the LSTM output at its last time step.\n",
    "\n",
    "    Args:\n",
    "        input_size: number of features per time step; None (default) uses the notebook constant INPUT_SIZE.\n",
    "        hidden_size: number of LSTM hidden units, also the input width of the output layer.\n",
    "        num_layers: number of stacked LSTM layers.\n",
    "        num_classes: number of scores produced per sample.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size=None, hidden_size=64, num_layers=1, num_classes=10):\n",
    "        super(RNN, self).__init__()\n",
    "        if input_size is None:\n",
    "            input_size = INPUT_SIZE     # read at construction time, as the original hard-coded argument was\n",
    "\n",
    "        self.rnn = nn.LSTM(         # the original author notes that a plain nn.RNN() hardly learns here\n",
    "            input_size=input_size,\n",
    "            hidden_size=hidden_size,    # rnn hidden units\n",
    "            num_layers=num_layers,      # number of rnn layers\n",
    "            batch_first=True,           # input & output have batch size as the first dimension, e.g. (batch, time_step, input_size)\n",
    "        )\n",
    "\n",
    "        # the same hidden_size feeds both layers, so they cannot drift apart\n",
    "        self.out = nn.Linear(hidden_size, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map x of shape (batch, time_step, input_size) to scores of shape (batch, num_classes).\"\"\"\n",
    "        # r_out shape (batch, time_step, hidden_size)\n",
    "        # the final hidden and cell states, each (n_layers, batch, hidden_size), are not used\n",
    "        r_out, _ = self.rnn(x, None)   # None represents zero initial hidden state\n",
    "\n",
    "        # choose r_out at the last time step\n",
    "        return self.out(r_out[:, -1, :])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RNN (\n",
      "  (rnn): LSTM(28, 64, batch_first=True)\n",
      "  (out): Linear (64 -> 10)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# instantiate the network and print its layer summary\n",
    "rnn = RNN()\n",
    "print(rnn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)   # optimize all rnn parameters\n",
    "loss_func = nn.CrossEntropyLoss()                       # the target labels are class indices, not one-hot vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch:  0 | train loss: 2.3127 | test accuracy: 0.14\n",
      "Epoch:  0 | train loss: 1.1182 | test accuracy: 0.56\n",
      "Epoch:  0 | train loss: 0.9374 | test accuracy: 0.68\n",
      "Epoch:  0 | train loss: 0.6279 | test accuracy: 0.75\n",
      "Epoch:  0 | train loss: 0.8004 | test accuracy: 0.75\n",
      "Epoch:  0 | train loss: 0.3407 | test accuracy: 0.88\n",
      "Epoch:  0 | train loss: 0.3342 | test accuracy: 0.89\n",
      "Epoch:  0 | train loss: 0.3200 | test accuracy: 0.91\n",
      "Epoch:  0 | train loss: 0.3430 | test accuracy: 0.92\n",
      "Epoch:  0 | train loss: 0.1234 | test accuracy: 0.93\n",
      "Epoch:  0 | train loss: 0.2714 | test accuracy: 0.92\n",
      "Epoch:  0 | train loss: 0.1043 | test accuracy: 0.93\n",
      "Epoch:  0 | train loss: 0.2893 | test accuracy: 0.93\n",
      "Epoch:  0 | train loss: 0.0635 | test accuracy: 0.94\n",
      "Epoch:  0 | train loss: 0.2412 | test accuracy: 0.94\n",
      "Epoch:  0 | train loss: 0.1203 | test accuracy: 0.92\n",
      "Epoch:  0 | train loss: 0.0807 | test accuracy: 0.94\n",
      "Epoch:  0 | train loss: 0.1307 | test accuracy: 0.94\n",
      "Epoch:  0 | train loss: 0.1166 | test accuracy: 0.95\n"
     ]
    }
   ],
   "source": [
    "# training and testing\n",
    "for epoch in range(EPOCH):\n",
    "    for step, (x, y) in enumerate(train_loader):            # gives batch data\n",
    "        b_x = Variable(x.view(-1, TIME_STEP, INPUT_SIZE))   # reshape x to (batch, time_step, input_size)\n",
    "        b_y = Variable(y)                                   # batch y\n",
    "\n",
    "        output = rnn(b_x)                                   # rnn output\n",
    "        loss = loss_func(output, b_y)                       # cross entropy loss\n",
    "        optimizer.zero_grad()                               # clear gradients for this training step\n",
    "        loss.backward()                                     # backpropagation, compute gradients\n",
    "        optimizer.step()                                    # apply gradients\n",
    "\n",
    "        if step % 50 == 0:                                  # report progress every 50 mini-batches\n",
    "            test_output = rnn(test_x)                       # test_x is (samples, time_step, input_size)\n",
    "            pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()\n",
    "            accuracy = (pred_y == test_y).sum() / float(test_y.size)\n",
    "            # loss.data[0] raises IndexError on the 0-dim loss tensors of PyTorch >= 0.5;\n",
    "            # use .item() when it exists and fall back to indexing on old releases that lack it\n",
    "            loss_value = loss.item() if hasattr(loss, 'item') else loss.data[0]\n",
    "            print('Epoch: ', epoch, '| train loss: %.4f' % loss_value, '| test accuracy: %.2f' % accuracy)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[7 2 1 0 4 1 4 9 5 9] prediction number\n",
      "[7 2 1 0 4 1 4 9 5 9] real number\n"
     ]
    }
   ],
   "source": [
    "# compare the predictions for the first 10 test samples with their real labels\n",
    "first_ten = test_x[:10].view(-1, 28, 28)\n",
    "test_output = rnn(first_ten)\n",
    "pred_index = torch.max(test_output, 1)[1]       # index of the largest score per sample\n",
    "pred_y = pred_index.data.numpy().squeeze()\n",
    "print(pred_y, 'prediction number')\n",
    "print(test_y[:10], 'real number')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
